import mma
import pandas as pd
import gudhi as gd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from gudhi.point_cloud.timedelay import TimeDelayEmbedding
from sklearn.neighbors import KernelDensity, KNeighborsClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import GridSearchCV
from scipy.io import loadmat
from random import choice
from sklearn import svm
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from os import chdir, getcwd,walk
from os.path import expanduser
from joblib import Parallel, delayed, cpu_count
from xgboost import XGBClassifier
import networkx as nx
from sklearn.model_selection import StratifiedKFold as sKFold
from DTM_filtration import *
from multiprocessing import Pool



# Root directory of the local dataset collection ("~" is expanded to the
# current user's home). Loaders append sub-paths to it by plain string
# concatenation, which is why it must keep its trailing "/".
DATASET_PATH = expanduser("~/Datasets/")




def get_graph_dataset(dataset, label_col = -2):
	label_col = -2 if (dataset in ["IMDB-MULTI","REDDIT5K","COX2"]) else label_col
	path = DATASET_PATH + dataset  +"/mat/"
	ppties = []
	gs = []
	for root, dir, files in walk(path):
		for file in files:
			ppty = []
			is_last_int = False
			for char in file:
				if not(char.isnumeric()):
					if is_last_int:
						ppty[-1] = (int)(ppty[-1])
					is_last_int = False
				else:
					if is_last_int:
						ppty[-1] += char
					else:
						ppty.append(char)
						is_last_int = True
			ppties.append(ppty)
			adj_mat = np.array(loadmat(path + file)['A'])
			gs.append(nx.Graph(adj_mat))
	labels = np.array(ppties)[:,label_col]
	return gs, labels

def constant_estimator_performance(labels):
	"""Print the accuracy of always predicting the most frequent label.

	This is the majority-class baseline: the fraction of ``labels`` equal to
	their most common value, rounded to 2 decimals. It is followed by the number
	of samples.

	Raises:
		ValueError: if ``labels`` is empty (the baseline is undefined).
	"""
	# np.unique counts each distinct label directly. No integer encoding or
	# manual tally is needed.
	_, counts = np.unique(np.asarray(labels), return_counts=True)
	if counts.size == 0:
		raise ValueError("labels must be non-empty")
	n_samples = int(counts.sum())
	print("Constant estimator accuracy :", np.round(counts.max()/n_samples, decimals=2))
	print("Number of graphs :", n_samples)


def get_UCR_dataset(dataset = "Coffee", test = False):
	"""Load one split of a UCR time-series dataset.

	Reads ``DATASET_PATH/UCR/<dataset>/<dataset>_TRAIN.tsv``, or ``_TEST.tsv``
	when ``test`` is true, as a headerless tab-separated table.

	Returns:
		(series, labels): ``series`` holds every column except the first and the
		last. ``labels`` is the first column, integer-encoded with a fresh
		``LabelEncoder``.
	"""
	suffix = "_TEST.tsv" if test else "_TRAIN.tsv"
	table = pd.read_csv(f"{DATASET_PATH}UCR/{dataset}/{dataset}{suffix}", delimiter='\t', header=None, index_col=None).to_numpy()
	# NOTE(review): the last column is dropped along with the label column.
	# Confirm this is intended (trailing empty column?) and not an off-by-one.
	series = table[:, 1:-1]
	# NOTE(review): train and test labels are encoded independently. The codes
	# only agree if both splits contain the same label set.
	labels = LabelEncoder().fit_transform(table[:, 0])
	return series, labels


def compute_img(simplextree:"mma.SimplexTreeMulti", **kwargs):
	"""Approximate the persistence module of ``simplextree`` and compute its images.

	The module is obtained with
	``simplextree.persistence_approximation(precision=..., box=..., threshold=False)``.
	One ``mod.image(...)`` is then computed per combination of
	(dimension, resolution, bandwidth, normalize, p), nested in that order.

	Keyword arguments (all optional):
		dimension: single homology dimension. When given (even 0) it overrides
			``dimensions``.
		dimensions: iterable of dimensions, default ``range(maxH + 1)``.
		maxH: default 0, used only for the ``dimensions`` default.
		precision: default 0.01.
		box: default [[0,0],[1,1]].
		resolution / resolutions: default [100,100] / [resolution].
		img_bandwidth / img_bandwidths: default 0.1 / [img_bandwidth].
		normalize / normalizes: default False / [normalize].
		p / ps: default 1 / [p].
		plot: default False; forwarded to ``image`` as ``plot``.
		size: default 2; forwarded to ``image``.
		colorbar: default True; forwarded to ``image`` as ``cb``.
		flatten: default True.

	Returns:
		With ``flatten``, a 1-D array: every image raveled and concatenated in
		the order above. Otherwise ``np.array`` of the list of images.
	"""
	## ARGS PARSING
	maxH = kwargs.get("maxH", 0)
	# Test for presence, not truthiness: an explicit dimension=0 must win.
	if kwargs.get("dimension") is not None:
		dimensions = [kwargs["dimension"]]
	else:
		dimensions = kwargs.get("dimensions", range(maxH+1))
	precision = kwargs.get("precision", 0.01)
	box = kwargs.get("box", [[0,0], [1,1]])
	resolution = kwargs.get("resolution", [100,100])
	resolutions = kwargs.get("resolutions", [resolution])
	p = kwargs.get("p", 1)
	ps = kwargs.get("ps", [p])
	normalize = kwargs.get("normalize", False)
	normalizes = kwargs.get("normalizes", [normalize])
	img_bandwidth = kwargs.get("img_bandwidth", 0.1)
	img_bandwidths = kwargs.get("img_bandwidths", [img_bandwidth])
	plot_flag = kwargs.get("plot", False)
	size = kwargs.get("size", 2)
	flatten = kwargs.get("flatten",True)
	cb = kwargs.get("colorbar", True)

	## COMPUTE
	mod = simplextree.persistence_approximation(precision=precision, box=box, threshold=False)
	out = [
		mod.image(dimension=dim, resolution=res, bandwidth=bandwidth, normalize=norm, p=power, plot=plot_flag, cb=cb, size=size)
		for dim in dimensions
		for res in resolutions
		for bandwidth in img_bandwidths
		for norm in normalizes
		for power in ps
	]
	if flatten:
		# Ravel each image before concatenating so images of different
		# resolutions can be combined. For same-shaped images this equals
		# np.concatenate(out).flatten().
		return np.concatenate([np.ravel(img) for img in out])
	return np.array(out)


def compute_imgs(iterator, get_bf, multithreads = True, **kwargs):
	"""Apply ``compute_img`` to every item of ``iterator``; return the list of results.

	Each item is turned into a simplex tree by ``get_bf(item, **kwargs)`` and
	then imaged by ``compute_img(tree, **kwargs)``. The same ``kwargs`` are
	forwarded to both. Results keep the input order.

	With ``multithreads`` (default), the ``compute_img`` calls run through a
	joblib ``Parallel`` with ``prefer="threads"``. Its size comes from
	``kwargs["n_jobs"]`` (default 5) and its verbosity from
	``kwargs["verbose"]`` (default 0). ``get_bf`` is evaluated as joblib
	consumes the task generator (at dispatch time), not inside the task.
	"""
	if not multithreads:
		return [compute_img(get_bf(data, **kwargs), **kwargs) for data in tqdm(iterator)]
	jobs = kwargs.get("n_jobs", 5)
	verbosity = kwargs.get("verbose", 0)
	with Parallel(n_jobs=jobs, verbose=verbosity, prefer="threads") as pool:
		tasks = (delayed(compute_img)(get_bf(data, **kwargs), **kwargs) for data in tqdm(iterator))
		results = pool(tasks)
	return results

def compute_dump_mod(simplextree:mma.SimplexTreeMulti, save="", **kwargs):
	"""Approximate the module of ``simplextree`` and return or pickle its dump.

	``kwargs`` are forwarded verbatim to
	``simplextree.persistence_approximation``, and ``.dump()`` is taken on the
	result.

	Args:
		save: if empty (default), the dump is returned. Otherwise it is pickled
			to ``save + ".pkl"`` (opened with "wb", so an existing file is
			overwritten) and ``None`` is returned.
	"""
	dumped = simplextree.persistence_approximation(**kwargs).dump()
	if save:
		import pickle
		with open(save + ".pkl", "wb") as handle:
			pickle.dump(dumped, handle)
		return None
	return dumped

def compute_mods(iterator, get_bf, dump=False, save="", **kwargs):
	"""Compute the approximated module of every item of ``iterator`` in a thread pool.

	Each item is turned into a simplex tree by ``get_bf(item, **kwargs)`` and
	handed to ``compute_dump_mod(tree, **kwargs)``. The pool is a joblib
	``Parallel`` with ``prefer="threads"``, ``n_jobs=kwargs["n_jobs"]`` (default
	``cpu_count()``) and ``verbose=kwargs["verbose"]`` (default 0).
	NOTE(review): ``n_jobs``/``verbose`` stay in ``kwargs`` and therefore reach
	``persistence_approximation`` as well; confirm it tolerates them.

	Args:
		dump: if true, return the raw dumps. Otherwise rebuild modules with
			``mma.from_dump``.
		save: if non-empty, item ``i`` is pickled to ``f"{save}{i}.pkl"``
			instead of being kept in memory, and ``None`` is returned.
	"""
	n_jobs = kwargs.get("n_jobs", cpu_count())
	verbose = kwargs.get("verbose", 0)
	with Parallel(n_jobs=n_jobs, verbose=verbose, prefer="threads") as p:
		if save:
			# ``len`` only feeds the progress-bar total; plain iterators and
			# generators have none, so fall back to an open-ended bar.
			total = len(iterator) if hasattr(iterator, "__len__") else None
			p(delayed(compute_dump_mod)(get_bf(data, **kwargs), save=save+str(i), **kwargs) for i,data in tqdm(enumerate(iterator), total=total))
			return None
		mods = p(delayed(compute_dump_mod)(get_bf(data, **kwargs), **kwargs) for data in tqdm(iterator))
	return mods if dump else [mma.from_dump(mod) for mod in mods]




# Classifiers :

# Display names used when reporting scores, keyed by estimator class (looked
# up via ``cl.__class__``). The classes themselves are the keys; there is no
# need to construct an instance of each estimator at import time to get them.
cl_n = {
	svm.SVC: "SVM",
	XGBClassifier: "XGBoost",
	RandomForestClassifier: "RandomForest",
	DecisionTreeClassifier: "Tree",
	AdaBoostClassifier: "AdaBoost",
	GaussianProcessClassifier: "GaussianProcess",
	KNeighborsClassifier: "KNeighbors",
	GridSearchCV: "GridSearchCV",
	LinearDiscriminantAnalysis: "LDA",
	QuadraticDiscriminantAnalysis: "QDA",
	GaussianNB: "NaiveBayes",
}

def kfold_acc(cls,x,y, k=10):
	"""Stratified k-fold accuracy of several classifiers evaluated on the same folds.

	Args:
		cls: classifiers with ``fit``/``predict``. Each is refit in place on
			every fold, so afterwards each is fitted on the last fold.
		x, y: features and labels, indexable by the fold index arrays.
		k: number of folds for ``StratifiedKFold`` (default, unshuffled).

	Returns:
		One ``"<name> : <mean>% ±<std>"`` string per classifier. The values are
		the mean and std of the per-fold accuracies, in percent, rounded to 3
		decimals. ``<name>`` comes from ``cl_n`` and falls back to the class
		name for classifiers not listed there.
	"""
	accuracies = np.zeros((len(cls), k))
	for fold, (train_idx, test_idx) in enumerate(tqdm(sKFold(k).split(x,y))):
		for j, cl in enumerate(cls):
			cl.fit(x[train_idx], y[train_idx])
			pred = cl.predict(x[test_idx])
			accuracies[j][fold] = accuracy_score(y[test_idx], pred)
	report = []
	for cl, scores in zip(cls, accuracies):
		name = cl_n.get(cl.__class__, type(cl).__name__)
		report.append(f"{name} : {np.mean(scores*100).round(decimals=3)}% ±{np.std(scores*100).round(decimals=3)}")
	return report

def acc(cls, xtrain,ytrain, xtest,ytest, **kwargs):
	"""Fit each classifier on the train split and return its test accuracy.

	Args:
		cls: classifiers with ``fit``/``predict``; each is fitted in place.
		xtrain, ytrain, xtest, ytest: train/test features and labels.
		separate_length (kwarg, optional): chunk length ``l``. When set and
			smaller than ``len(xtrain)``, the data are treated as consecutive
			chunks of ``l`` rows, and each chunk is scored separately.
		show (kwarg, optional): print one ``"<name> : <acc>%"`` line per
			classifier.

	Returns:
		A float array with one accuracy per classifier. In chunked mode, a list
		with one such array per chunk.

	Raises:
		ValueError: if ``len(xtrain)`` is not a multiple of ``separate_length``.
	"""
	l = kwargs.get("separate_length", False)
	# Chunk only when the data span more than one chunk. Each recursive call
	# then receives exactly ``l`` rows, so it does not split again.
	if l and l < len(xtrain):
		n_chunks, remainder = divmod(len(xtrain), l)
		if remainder:
			raise ValueError(f"len(xtrain)={len(xtrain)} is not a multiple of separate_length={l}")
		# NOTE(review): xtest/ytest are cut with the same chunk length ``l`` as
		# the train data. Confirm both splits share that layout.
		return [
			acc(cls, xtrain[i*l:(i+1)*l], ytrain[i*l:(i+1)*l], xtest[i*l:(i+1)*l], ytest[i*l:(i+1)*l], **kwargs)
			for i in range(n_chunks)
		]
	accuracies = np.zeros(len(cls))
	for j, cl in tqdm(enumerate(cls)):
		cl.fit(xtrain, ytrain)
		pred = cl.predict(xtest)
		accuracies[j] = accuracy_score(y_true=ytest, y_pred=pred)
	if kwargs.get("show"):
		# The loop variable must not be called ``acc``. Binding that name
		# anywhere in this body makes it local and breaks the recursive
		# ``acc(...)`` call above with UnboundLocalError.
		for cl, score in zip(cls, accuracies):
			name = cl_n.get(cl.__class__, type(cl).__name__)
			print(f"{name} : {np.mean(score*100).round(decimals=3)}%")
	return accuracies



# Usual Boundary - filtrations functions
def bf_alpha_dens(x, **kwargs):
	"""Build a 2-parameter ``mma.SimplexTreeMulti``: alpha complex x codensity.

	Duplicate rows of ``x`` are removed. An alpha complex is then built with
	``max_alpha_square = threshold**2`` (``threshold`` kwarg, default 0.2), and
	its gudhi simplex tree is converted with ``num_parameters=2``. The alpha
	filtration presumably becomes parameter 0 (confirm against mma).

	Parameter 1 is filled with ``fill_lowerstar`` from per-vertex values
	``-score_samples(vertex) * scale[1] - translate[1]``. Here ``score_samples``
	is a ``KernelDensity`` (``kde_kernel`` default "gaussian", ``kde_bandwidth``
	default 0.4) fitted on the deduplicated points. ``scale`` defaults to [1, 1]
	and ``translate`` to [0, 0]; only index 1 of each is used.
	"""
	points = np.unique(x, axis=0)
	alpha = gd.AlphaComplex(points = points)
	tree = alpha.create_simplex_tree(max_alpha_square = kwargs.get("threshold", 0.2)**2)
	# Coordinates of the vertices, in vertex-index order.
	vertices = np.array([alpha.get_point(v) for v in range(tree.num_vertices())])
	density = KernelDensity(kernel=kwargs.get("kde_kernel", "gaussian"), bandwidth=kwargs.get("kde_bandwidth",0.4)).fit(points)
	scale = kwargs.get("scale", [1]*2)
	translate = kwargs.get("translate", [0]*2)
	codensity = -np.array(density.score_samples(vertices))*scale[1]
	codensity -= translate[1]

	multi = mma.SimplexTreeMulti(tree, num_parameters=2)
	multi.fill_lowerstar(codensity, parameter=1)
	return multi

